import mxnet as mx
import numpy as np
import argparse
import json
from mxnet.contrib import onnx as onnx_mxnet
from onnx import checker
import onnx
import logging
logging.basicConfig(level=logging.INFO)

def onnx2mxnet(onnx_file, inst_file, epoch=200):
    """Convert an ONNX model file into an MXNet checkpoint.

    The ONNX graph is imported with ``mxnet.contrib.onnx.import_model`` and the
    resulting symbol and parameters are written with
    ``mx.model.save_checkpoint``.

    Args:
        onnx_file: Path of the ONNX model to import.
        inst_file: Checkpoint prefix handed to ``mx.model.save_checkpoint``.
        epoch: Epoch number stored in the checkpoint (it becomes part of the
            params filename). Defaults to 200, the value this script has
            always used.
    """
    print("onnx_file:%s" % (onnx_file))

    sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_file)

    # No Module bind is needed just to persist the graph: the imported symbol
    # and parameter dicts can be saved directly.
    mx.model.save_checkpoint(inst_file, epoch, sym, arg_params, aux_params)

def make_parser():
    """Build the command-line parser for the ONNX -> MXNet converter.

    Returns:
        argparse.ArgumentParser accepting ``--config``, the path to a
        ``key=value`` config file. When ``--config`` is omitted it defaults to
        ``./net_inst.txt``.
    """
    parser = argparse.ArgumentParser()
    # The original also set required=True, which made the default unreachable;
    # dropping it lets the declared default actually apply.
    parser.add_argument('--config', default="./net_inst.txt", type=str,
                        help='config file (key=value lines) providing onnx_file and inst_file')
    return parser

def parse_config(text):
    """Parse ``key=value`` settings from the text of a conversion config file.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. Keys and values are stripped of surrounding whitespace (this also
    drops carriage returns left by Windows line endings). Only the first ``=``
    separates key from value, so values may themselves contain ``=``. When a
    key appears more than once, the last occurrence wins.

    Args:
        text: Full contents of the config file.

    Returns:
        dict mapping each key to its string value.
    """
    settings = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if sep:
            settings[key.strip()] = value.strip()
    return settings


if __name__ == "__main__":
    parser = make_parser()
    args = parser.parse_args()
    with open(args.config) as config_fh:
        settings = parse_config(config_fh.read())
    # Report absent keys as a CLI error instead of crashing with a NameError.
    missing = [key for key in ("onnx_file", "inst_file") if key not in settings]
    if missing:
        parser.error("%s is missing required key(s): %s"
                     % (args.config, ", ".join(missing)))
    onnx2mxnet(settings["onnx_file"], settings["inst_file"])
